/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "../../../../../../kernel_impl/broadcast_custom.h"

/**
 * Kernel entry point for the custom broadcast operator.
 *
 * Unpacks the host-side tiling data and dispatches to the matching
 * KernelBroadcastCustom instantiation. The template arguments appear to be
 * (element type, rank, broadcast axis) — NOTE(review): confirm against
 * broadcast_custom.h.
 *
 * @param x         global-memory address of the source tensor
 * @param y         global-memory address of the destination tensor
 * @param workspace workspace address (not referenced by this kernel)
 * @param tiling    global-memory address of the serialized tiling data
 */
extern "C" __global__ __aicore__ void broadcast_custom(GM_ADDR x, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
    GET_TILING_DATA(tilingData, tiling);

    // Only tiling key 1 (float elements) is implemented; any other key does nothing.
    if (TILING_KEY_IS(1)) {
        const uint32_t rank = tilingData.dim;
        const uint32_t broadcastAxis = tilingData.axis;

        if (rank == 1) {
            // 1-D case: the element count equals the single dimension.
            const uint32_t srcShape[] = {tilingData.srcFirstDim};
            const uint32_t dstShape[] = {tilingData.dstFirstDim};
            KernelBroadcastCustom<float, 1, 0> kernel;
            kernel.Init(x, y, tilingData.srcFirstDim, tilingData.dstFirstDim, srcShape, dstShape);
            kernel.Process();
        } else {
            // Every non-1 rank is handled as a 2-D broadcast.
            const uint32_t srcShape[] = {tilingData.srcFirstDim, tilingData.srcLastDim};
            const uint32_t dstShape[] = {tilingData.dstFirstDim, tilingData.dstLastDim};
            // Flattened element counts; `auto` keeps the exact type of the original product expression.
            const auto srcTotal = tilingData.srcFirstDim * tilingData.srcLastDim;
            const auto dstTotal = tilingData.dstFirstDim * tilingData.dstLastDim;

            // The broadcast axis is a template parameter, so each axis needs its own instantiation.
            if (broadcastAxis != 0) {
                KernelBroadcastCustom<float, 2, 1> kernel;
                kernel.Init(x, y, srcTotal, dstTotal, srcShape, dstShape);
                kernel.Process();
            } else {
                KernelBroadcastCustom<float, 2, 0> kernel;
                kernel.Init(x, y, srcTotal, dstTotal, srcShape, dstShape);
                kernel.Process();
            }
        }
    }
}